{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b48a5bf-c945-4304-98d8-9c5d061d2540",
   "metadata": {},
   "outputs": [],
   "source": [
    "from diffusers import DiffusionPipeline\n",
    "from diffusers.callbacks import PipelineCallback\n",
    "import matplotlib.pyplot as plt\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f453935-ee11-4d1c-8424-6faf7baf72aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def display_slider_images(images, titles):\n",
    "    \"\"\"Show images side by side in a single row, each captioned with its title.\n",
    "\n",
    "    Args:\n",
    "        images: Sequence of images that `Axes.imshow` can draw.\n",
    "        titles: Captions, one per image, in the same order as `images`.\n",
    "\n",
    "    Images and titles are paired with `zip`, so pairing stops at the shorter of\n",
    "    the two. Nothing is drawn when `images` is empty.\n",
    "    \"\"\"\n",
    "    if len(images) == 0:\n",
    "        # plt.subplots cannot build a grid with zero columns, so there is nothing to show.\n",
    "        return\n",
    "\n",
    "    # squeeze=False always returns a 2-D array of axes, so a single image\n",
    "    # needs no special handling.\n",
    "    fig, axes = plt.subplots(1, len(images), figsize=(len(images) * 3, 3), squeeze=False)\n",
    "\n",
    "    for ax, img, title in zip(axes[0], images, titles):\n",
    "        ax.imshow(img)\n",
    "        ax.set_title(title)\n",
    "        ax.axis('off')\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "430fe187-43da-4de6-a56c-9920e86bc9b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ConceptSliderCallback(PipelineCallback):\n",
    "    \"\"\"\n",
    "    Enable Concept Slider after certain number of steps (set by `slider_strength`).\n",
    "\n",
    "    Until the attach step the adapters keep whatever scale the caller configured (set them\n",
    "    to `0.0` before calling the pipeline). When the attach step ends, the adapters are set\n",
    "    to `slider_scales`; when the final step ends, they are set back to `0.0`.\n",
    "\n",
    "    Use strength < 1 if you want more precise edits (recommend: .7 - .9)\n",
    "\n",
    "    Args:\n",
    "        slider_strength: Places the attach step at\n",
    "            `num_timesteps - int(num_timesteps * slider_strength)`.\n",
    "        slider_names: Adapter name, or list of adapter names, to control. Must be provided.\n",
    "        slider_scales: LoRA scales applied to `slider_names` from the attach step onwards.\n",
    "    \"\"\"\n",
    "    tensor_inputs = []\n",
    "\n",
    "    def __init__(self, slider_strength=1, slider_names=None, slider_scales=[0]):\n",
    "        super().__init__()\n",
    "        # Wrap a bare string so that len(self.slider_names) counts adapters, not characters,\n",
    "        # when the zero weights are built in callback_fn.\n",
    "        if isinstance(slider_names, str):\n",
    "            slider_names = [slider_names]\n",
    "        # Copy sequence inputs so the instance never aliases the shared default list\n",
    "        # or a list the caller keeps mutating.\n",
    "        if isinstance(slider_scales, (list, tuple)):\n",
    "            slider_scales = list(slider_scales)\n",
    "        self.slider_names = slider_names\n",
    "        self.slider_scales = slider_scales\n",
    "        self.slider_strength = slider_strength\n",
    "\n",
    "    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs):\n",
    "        \"\"\"Toggle the slider adapters on `pipeline` and return `callback_kwargs` unchanged.\"\"\"\n",
    "        total_steps = pipeline.num_timesteps\n",
    "        # Index of the step after which the slider is switched on (0 when strength is 1).\n",
    "        attach_step = total_steps - int(total_steps * self.slider_strength)\n",
    "\n",
    "        # Act on the pipeline passed to the callback, not on a notebook-level global,\n",
    "        # so the callback works whatever the pipeline variable is called.\n",
    "        if step_index == attach_step:\n",
    "            pipeline.set_adapters(self.slider_names, adapter_weights=self.slider_scales)\n",
    "\n",
    "        # After the final step set the slider back to 0 (a start-of-step callback in\n",
    "        # diffusers would allow a cleaner implementation).\n",
    "        if step_index == total_steps - 1:\n",
    "            pipeline.set_adapters(self.slider_names, adapter_weights=[0.] * len(self.slider_names))\n",
    "\n",
    "        return callback_kwargs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7edf5967-7664-4ebe-b2e2-6137b421e2c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the SDXL base pipeline in bfloat16, then move it to the GPU.\n",
    "model_id = \"stabilityai/stable-diffusion-xl-base-1.0\"\n",
    "pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)\n",
    "pipe.to(\"cuda\")\n",
    "pipe.to(torch.bfloat16)\n",
    "\n",
    "# Point this at your own trained slider (a .pt or .safetensors file both load with diffusers).\n",
    "adapter_name = 'age'\n",
    "adapter_path = 'models/sdxl-conceptslider-age.safetensors'\n",
    "\n",
    "# Register the slider as a named LoRA adapter so later cells can address it by name.\n",
    "pipe.load_lora_weights(adapter_path, adapter_name=adapter_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da032a9b-fdac-4b98-9f49-8b82d0b7d5bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from a neutral slider: the callback only switches the adapter on part-way through\n",
    "# sampling, so the early steps must run at scale 0 (better implementation needed).\n",
    "pipe.set_adapters(adapter_name, adapter_weights=0)\n",
    "\n",
    "prompt = \"image of a person, realistic, 8k\"\n",
    "negative_prompt = None\n",
    "slider_scales = [-5, -2.5, 0, 2.5, 5]\n",
    "seed = 10\n",
    "\n",
    "# Sampling settings shared by every scale in the sweep.\n",
    "sampling_kwargs = dict(\n",
    "    negative_prompt=negative_prompt,\n",
    "    guidance_scale=7,\n",
    "    num_inference_steps=20,\n",
    ")\n",
    "\n",
    "images = []\n",
    "for scale in slider_scales:\n",
    "    # A fresh callback per scale, controlling the single loaded adapter.\n",
    "    sliders_fn = ConceptSliderCallback(\n",
    "        slider_strength=.9,\n",
    "        slider_names=[adapter_name],\n",
    "        slider_scales=[scale],\n",
    "    )\n",
    "\n",
    "    # Re-seed before every call so each scale is sampled from the same seed.\n",
    "    output = pipe(\n",
    "        prompt,\n",
    "        generator=torch.manual_seed(seed),\n",
    "        callback_on_step_end=sliders_fn,\n",
    "        **sampling_kwargs,\n",
    "    )\n",
    "    image = output.images[0]\n",
    "    images.append(image)\n",
    "\n",
    "display_slider_images(images, slider_scales)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
